"""OpenTelemetry Milvus DB instrumentation"""

import logging
import pymilvus

from typing import Collection

from opentelemetry.instrumentation.milvus.config import Config
from opentelemetry.metrics import get_meter
from opentelemetry.trace import get_tracer
from wrapt import wrap_function_wrapper

from opentelemetry.semconv_ai import Meters
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap

from opentelemetry.instrumentation.milvus.wrapper import _wrap
from opentelemetry.instrumentation.milvus.version import __version__
from opentelemetry.instrumentation.milvus.utils import is_metrics_enabled

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Requirement specifier reported by ``instrumentation_dependencies()``.
_instruments = ("pymilvus >= 2.4.1",)

# One descriptor per instrumented call. "package", "object" and "method"
# locate the attribute that gets patched; the whole dict (including
# "span_name") is handed unchanged to ``_wrap``.
WRAPPED_METHODS = [
    {
        "package": pymilvus,
        "object": "MilvusClient",
        "method": method_name,
        "span_name": f"milvus.{method_name}",
    }
    for method_name in (
        "create_collection",
        "insert",
        "upsert",
        "delete",
        "search",
        "get",
        "query",
        "hybrid_search",
    )
]


class MilvusInstrumentor(BaseInstrumentor):
    """An instrumentor for Milvus's client library.

    Patches every method described in ``WRAPPED_METHODS`` with the wrapper
    produced by ``_wrap`` and restores the originals on uninstrumentation.
    """

    def __init__(self, exception_logger=None):
        """Create the instrumentor.

        Args:
            exception_logger: Value stored on ``Config.exception_logger``.
                This is a class attribute of ``Config``, so the most recently
                constructed instrumentor determines the value in effect.
        """
        super().__init__()
        Config.exception_logger = exception_logger

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package requirement specifiers this instrumentor needs."""
        return _instruments

    def _instrument(self, **kwargs):
        """Wrap the configured Milvus client methods.

        Args:
            **kwargs: ``tracer_provider`` and ``meter_provider`` are read if
                present and passed through to ``get_tracer`` / ``get_meter``.
        """
        # Set default values in case metrics are disabled
        query_duration_metric = None
        distance_metric = None
        insert_units_metric = None
        upsert_units_metric = None
        delete_units_metric = None

        if is_metrics_enabled():
            meter_provider = kwargs.get("meter_provider")
            meter = get_meter(__name__, __version__, meter_provider)

            query_duration_metric = meter.create_histogram(
                Meters.DB_QUERY_DURATION,
                "s",
                "Duration of query operations",
            )
            distance_metric = meter.create_histogram(
                Meters.DB_SEARCH_DISTANCE,
                "",
                "Distance between search query vector and matched vectors",
            )
            insert_units_metric = meter.create_counter(
                Meters.DB_USAGE_INSERT_UNITS,
                "",
                "Number of insert units consumed in serverless calls",
            )
            upsert_units_metric = meter.create_counter(
                Meters.DB_USAGE_UPSERT_UNITS,
                "",
                "Number of upsert units consumed in serverless calls",
            )
            delete_units_metric = meter.create_counter(
                Meters.DB_USAGE_DELETE_UNITS,
                "",
                "Number of delete units consumed in serverless calls",
            )

        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, __version__, tracer_provider)

        for wrapped_method in WRAPPED_METHODS:
            wrap_package = wrapped_method.get("package")
            wrap_object = wrapped_method.get("object")
            wrap_method = wrapped_method.get("method")

            target = getattr(wrap_package, wrap_object, None)
            if not target:
                continue

            # wrap_function_wrapper resolves "<object>.<method>" via attribute
            # lookup and raises AttributeError when the method is missing,
            # which would abort instrumentation of every remaining entry.
            # Skip methods the installed pymilvus does not provide instead
            # (presumably newer additions such as hybrid_search on older
            # releases -- TODO confirm the exact version boundary).
            if not hasattr(target, wrap_method):
                logger.debug(
                    "Skipping instrumentation of %s.%s: attribute not found",
                    wrap_object,
                    wrap_method,
                )
                continue

            wrap_function_wrapper(
                wrap_package,
                f"{wrap_object}.{wrap_method}",
                _wrap(
                    tracer,
                    query_duration_metric,
                    distance_metric,
                    insert_units_metric,
                    upsert_units_metric,
                    delete_units_metric,
                    wrapped_method,
                ),
            )

    def _uninstrument(self, **kwargs):
        """Undo ``_instrument`` by unwrapping every configured method.

        Args:
            **kwargs: Accepted for interface compatibility; not read here.
        """
        for wrapped_method in WRAPPED_METHODS:
            wrap_package = wrapped_method.get("package")
            wrap_object = wrapped_method.get("object")

            wrapped = getattr(wrap_package, wrap_object, None)
            if wrapped:
                unwrap(wrapped, wrapped_method.get("method"))
